from wenet.transformer.encoder import TransformerEncoder
def count_parameters(model, *, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
        trainable_only: When True, count only parameters with
            ``requires_grad=True``. Defaults to False (count everything).

    Returns:
        The summed ``numel()`` of the selected parameters, as an int.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Encoder configuration whose size is being measured. The argument values
# are passed straight through to WeNet's TransformerEncoder; see that class
# for their exact meaning.
speech_transformer = TransformerEncoder(
    input_size=1280,
    output_size=1280,
    attention_heads=4,
    linear_units=2048,
    num_blocks=4,
    dropout_rate=0.1,
    positional_dropout_rate=0.1,
    attention_dropout_rate=0.0,
    input_layer="linear",
    pos_enc_layer_type="abs_pos",
    normalize_before=True,
)

num_params = count_parameters(speech_transformer)
# Parameter counts are conventionally reported in decimal millions (1e6).
# The previous `/1024/1024` produced binary "Mi" units, which understate
# the count by ~4.6% and are easy to misread as millions.
print(f"TransformerEncoder parameters: {num_params:,} ({num_params / 1e6:.2f} M)")
